In [1]:
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm

import tensorflow as tf
import keras
from keras import layers
from keras.utils import np_utils
from keras.preprocessing.sequence import skipgrams

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
)
from sklearn.model_selection import (
    train_test_split as tts,
    StratifiedKFold
)
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.multiclass import type_of_target

import seaborn as sns

from Bio import SeqIO
from Bio.Seq import Seq

In [34]:
# Load the raw data: `Seq` holds the nucleotide sequences that get split into
# k-mers below; `Level` is presumably the class label (confirm against the CSV).
data = pd.read_csv('../data/pro_nonpro.csv')
x, y = data['Seq'], data['Level']

In [35]:
kmer_size = 4
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42

# SEED was previously declared but never applied. Seed every RNG the pipeline
# can touch: Python's `random` (the keras `skipgrams` helper draws from it —
# verify for your keras version), NumPy, and TensorFlow (dataset shuffling,
# negative sampling, weight initialisation).
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [36]:
def get_kmers(sequences, kmer=4):
    """Split every sequence in a pandas Series into overlapping k-mers.

    Parameters
    ----------
    sequences : pandas.Series of str
        Sequences to split. The result keeps the same index and name.
    kmer : int, default 4
        Length of each k-mer. Must be greater than 1.

    Returns
    -------
    pandas.Series
        A new Series whose values are lists of k-mers read with a stride of 1
        (``len(seq) - kmer + 1`` items, or an empty list when the sequence is
        shorter than ``kmer``). The input Series is not modified.

    Raises
    ------
    ValueError
        If ``kmer`` is 1 or less.
    """
    if kmer <= 1:
        raise ValueError("kmer size must be greater than 1")
    # Series.map builds a fresh Series (index and name preserved), replacing the
    # Series.iteritems() loop: iteritems() was removed in pandas 2.0.
    return sequences.map(
        lambda seq: [seq[start:start + kmer] for start in range(len(seq) - kmer + 1)]
    )


In [37]:
def generate_training_data(sequences, window_size, num_ns, vocab_size, seed):
    """Build skip-gram training examples with negative sampling.

    For each integer-encoded sequence, keras ``skipgrams`` yields positive
    (target, context) pairs within ``window_size``. Each positive pair is then
    combined with ``num_ns`` negative context ids drawn by
    ``tf.random.log_uniform_candidate_sampler``.

    Parameters
    ----------
    sequences : iterable of int sequences
        Integer-encoded sequences, e.g. the output of a TextVectorization layer.
    window_size : int
        Maximum distance between a target and a positive context token.
    num_ns : int
        Number of negative samples drawn per positive pair.
    vocab_size : int
        Size of the id space; negatives are drawn from ``[0, vocab_size)``.
    seed : int
        Op-level seed passed to the negative sampler.

    Returns
    -------
    targets : list
        One target id per training example.
    contexts : list of tf.Tensor
        Each of shape ``(num_ns + 1, 1)``: the positive context id followed by
        the ``num_ns`` negative ids.
    labels : list of tf.Tensor
        Each of shape ``(num_ns + 1,)`` with value ``[1, 0, ..., 0]``.

    Raises
    ------
    ValueError
        If ``num_ns`` exceeds ``vocab_size``: that many unique negatives cannot
        be drawn from the id space.
    """
    if num_ns > vocab_size:
        raise ValueError(
            f"num_ns ({num_ns}) cannot exceed vocab_size ({vocab_size}) "
            "when drawing unique negative samples"
        )

    # Elements of each training example are appended to these lists.
    targets, contexts, labels = [], [], []

    # Sampling table over `vocab_size` ids, passed to skipgrams to subsample tokens.
    # NOTE(review): this table assumes ids are ranked by descending frequency — confirm
    # that holds for the vocabulary producing `sequences`.
    sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)

    # Every example has the same label vector, so build it once outside the loops.
    label = tf.constant([1] + [0] * num_ns, dtype="int64")

    for sequence in tqdm.tqdm(sequences):
        # Positive skip-gram pairs only; negatives are drawn separately below.
        positive_skip_grams, _ = tf.keras.preprocessing.sequence.skipgrams(
            sequence,
            vocabulary_size=vocab_size,
            sampling_table=sampling_table,
            window_size=window_size,
            negative_samples=0,
        )

        for target_word, context_word in positive_skip_grams:
            # Shape (1, 1): one batch row with one true class.
            context_class = tf.expand_dims(
                tf.constant([context_word], dtype="int64"), 1)
            # NOTE(review): nothing here filters the negatives against the positive
            # context id, so a sampled negative may coincide with it.
            negative_sampling_candidates, _, _ = tf.random.log_uniform_candidate_sampler(
                true_classes=context_class,
                num_true=1,
                num_sampled=num_ns,
                unique=True,
                range_max=vocab_size,
                seed=seed,
                name="negative_sampling",
            )

            # (num_ns,) -> (num_ns, 1) so it stacks under context_class.
            negative_sampling_candidates = tf.expand_dims(
                negative_sampling_candidates, 1)
            context = tf.concat([context_class, negative_sampling_candidates], 0)

            targets.append(target_word)
            contexts.append(context)
            labels.append(label)

    return targets, contexts, labels

In [38]:
class SkipgramVectorizer:
    """Minimal whitespace tokenizer that maps words to integer ids and builds skip-grams.

    Id 0 is reserved for ``'<pad>'``; new tokens get consecutive ids in
    first-seen order. Tokens are lower-cased before lookup. Methods that take
    ``sentences`` accept either a list of strings or a single string (treated
    as one sentence).
    """

    def __init__(self):
        # token -> id; id 0 is reserved for padding.
        self.vocab = {'<pad>': 0}
        # Next id to hand out.
        self.index = 1
        # Always equals len(self.vocab); previously started at 0 even though
        # '<pad>' already occupies an id.
        self.vocab_size = len(self.vocab)
        # id -> token, kept in sync with self.vocab.
        self.inverse_vocab = self.reverse_vocab(self.vocab)

    @staticmethod
    def _as_sentence_list(sentences):
        """Wrap a lone string so it is treated as one sentence, not as a sequence of characters."""
        if isinstance(sentences, str):
            return [sentences]
        return sentences

    def reverse_vocab(self, vocab):
        """Return the id -> token mapping for ``vocab``."""
        return {index: token for token, index in vocab.items()}

    def update_vocab(self, sentences: "str | list[str]"):
        """Add every unseen lower-cased whitespace token in ``sentences`` to the vocabulary.

        Updates ``vocab``, ``index``, ``vocab_size`` and ``inverse_vocab``.
        """
        for sentence in self._as_sentence_list(sentences):
            for token in sentence.lower().split():
                if token not in self.vocab:
                    self.vocab[token] = self.index
                    self.index += 1
        self.vocab_size = len(self.vocab)
        self.inverse_vocab = self.reverse_vocab(self.vocab)

    def vectorize(self, sentences: "str | list[str]"):
        """Encode each sentence as a list of token ids.

        Raises
        ------
        KeyError
            If a token is not in the vocabulary (call ``update_vocab`` first).
        """
        return [
            [self.vocab[word] for word in sentence.lower().split()]
            for sentence in self._as_sentence_list(sentences)
        ]

    def gen_skipgrams(self, sentences, win_size=2):
        """Return, for each id sequence in ``sentences``, the positive skip-gram pairs.

        Uses keras ``skipgrams`` with ``negative_samples=0`` and keeps only the
        couples (index 0 of its result), dropping the labels.
        """
        return [
            skipgrams(
                sentence,
                vocabulary_size=self.vocab_size,
                window_size=win_size,
                negative_samples=0,
            )[0]
            for sentence in sentences
        ]

    def __repr__(self):
        return str(self.vocab)


In [39]:
# update_vocab / vectorize iterate over their argument, so pass a list of
# sentences: a bare string would be consumed character by character.
example_sentences = ["The wide road shimmered in the hot sun"]

my_vectorizer = SkipgramVectorizer()
my_vectorizer.update_vocab(example_sentences)
vector = my_vectorizer.vectorize(example_sentences)
my_skipgrams = my_vectorizer.gen_skipgrams(vector)
my_vectorizer.vocab

In [40]:
# Cap on vocabulary entries. Assuming an ACGT alphabet there are at most
# 4**kmer_size (= 256 for 4-mers) distinct k-mers, plus the padding ('') and
# OOV ('[UNK]') entries, so 300 is comfortably large enough.
MAX_TOKENS = 300
# Every sequence is zero-padded / truncated to this many k-mers.
# NOTE(review): presumably the k-mer count of the longest sequence
# (len(seq) - kmer_size + 1); any longer sequence is silently truncated — confirm.
SEQUENCE_LENGTH = 81

vectorize_layer = layers.TextVectorization(
    # Inputs are already space-joined k-mers, so no standardization is applied.
    # `None` behaves like the previous identity lambda but keeps the layer
    # config serializable (a lambda cannot be saved with the layer config).
    standardize=None,
    max_tokens=MAX_TOKENS,
    output_mode='int',
    output_sequence_length=SEQUENCE_LENGTH,
)

In [41]:
# Split each sequence into overlapping k-mers of the configured size and join
# them into space-separated "sentences" for TextVectorization. Series.map builds
# a new Series instead of overwriting entries through Series.iteritems(),
# which was removed in pandas 2.0.
kmer_sequences = get_kmers(x, kmer=kmer_size).map(' '.join)

In [42]:
seq_tf_ds = tf.data.Dataset.from_tensor_slices(kmer_sequences)
batched_seq_ds = seq_tf_ds.batch(1024)

# Learn the k-mer vocabulary from the whole corpus.
vectorize_layer.adapt(batched_seq_ds)
vocab = vectorize_layer.get_vocabulary()
vocab_size = len(vocab)

# Encode every sequence to a fixed-length id vector and materialize them as NumPy arrays.
text_vector_ds = (
    batched_seq_ds
    .prefetch(AUTOTUNE)
    .map(vectorize_layer)
    .unbatch()
)
sequences = list(text_vector_ds.as_numpy_iterator())

In [11]:
# Preview a few encoded sequences next to their k-mers; printing every sequence
# floods the notebook output.
N_PREVIEW = 3
for seq in sequences[:N_PREVIEW]:
    print(f"{seq} => {[vocab[i] for i in seq]}")

[254 139  23 125 183 253 196 172 212  33  12  52  28  18 103  44 200 136
 128 207  26   2   6  25  28  18  13  20 105 170 172  57   3  16  18  13
   2   2  20 105 128 123  78 154  48  91  37  79  67 107  26   2   2  20
  69 178  75  11  60  90 121  25  79  67 100 101  85 173  90   8   6  77
  35 135 136  95 171  90   0   0   0] => ['TAGA', 'AGAT', 'GATG', 'ATGT', 'TGTC', 'GTCC', 'TCCT', 'CCTT', 'CTTG', 'TTGA', 'TGAT', 'GATT', 'ATTA', 'TTAA', 'TAAC', 'AACA', 'ACAC', 'CACC', 'ACCA', 'CCAA', 'CAAA', 'AAAA', 'AAAT', 'AATT', 'ATTA', 'TTAA', 'TAAA', 'AAAC', 'AACC', 'ACCT', 'CCTT', 'CTTT', 'TTTT', 'TTTA', 'TTAA', 'TAAA', 'AAAA', 'AAAA', 'AAAC', 'AACC', 'ACCA', 'CCAG', 'CAGG', 'AGGC', 'GGCA', 'GCAT', 'CATT', 'ATTC', 'TTCA', 'TCAA', 'CAAA', 'AAAA', 'AAAA', 'AAAC', 'AACG', 'ACGG', 'CGGC', 'GGCG', 'GCGA', 'CGAA', 'GAAT', 'AATT', 'ATTC', 'TTCA', 'TCAT', 'CATC', 'ATCG', 'TCGA', 'CGAA', 'GAAA', 'AAAT', 'AATC', 'ATCA', 'TCAC', 'CACC', 'ACCG', 'CCGA', 'CGAA', '', '', '']
[ 17  71 106   8   2   6  43  

In [43]:
# Skip-gram window and number of negative samples per positive pair.
WINDOW_SIZE = 4
NUM_NS = 20

targets, contexts, labels = generate_training_data(
    sequences,
    window_size=WINDOW_SIZE,
    num_ns=NUM_NS,
    vocab_size=vocab_size,
    seed=SEED,
)

100%|██████████| 6764/6764 [02:33<00:00, 44.05it/s]


In [44]:
# Stack the per-example lists into arrays. Each context tensor is
# (num_ns + 1, 1), so take index 0 of the trailing axis to get (examples, num_ns + 1).
targets = np.array(targets)
labels = np.array(labels)
contexts = np.array(contexts)[..., 0]

In [45]:
BATCH_SIZE = 1024
# NOTE(review): if the number of examples greatly exceeds this buffer, the
# shuffle is only local; raise it if memory allows.
BUFFER_SIZE = 10000
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
# Seed the shuffle so the example order is reproducible across runs.
dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED).batch(BATCH_SIZE, drop_remainder=True)

In [46]:
# No .cache() here: caching *after* shuffle replays the first epoch's order
# every epoch, so the data is never reshuffled. The examples already live in
# memory (from_tensor_slices over NumPy arrays), so prefetching is enough.
dataset = dataset.prefetch(buffer_size=AUTOTUNE)
dataset

<PrefetchDataset element_spec=((TensorSpec(shape=(1024,), dtype=tf.int64, name=None), TensorSpec(shape=(1024, 21), dtype=tf.int64, name=None)), TensorSpec(shape=(1024, 21), dtype=tf.int64, name=None))>


In [47]:
class Word2Vec(tf.keras.Model):
    """Skip-gram word2vec model that scores context candidates by dot product.

    Holds two embedding tables, one for target tokens and one for context
    tokens. ``call`` returns one logit per context candidate.

    Args:
        vocab_size: Number of rows in each embedding table.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # `input_length` is deliberately omitted. It only feeds static shape
        # inference, and the old values (1 and 4 + 1) did not match the real
        # inputs: targets are squeezed to (batch,) and contexts carry
        # num_ns + 1 columns (21 in this notebook). Newer Keras releases
        # deprecate the argument as well.
        self.target_embedding = layers.Embedding(
            vocab_size, embedding_dim, name="w2v_embedding")
        self.context_embedding = layers.Embedding(vocab_size, embedding_dim)

    def call(self, pair):
        """Score every context candidate against its target.

        Args:
            pair: ``(target, context)`` with ``target`` shaped ``(batch,)`` or
                ``(batch, 1)`` and ``context`` shaped ``(batch, num_candidates)``.

        Returns:
            Tensor ``(batch, num_candidates)`` holding the dot product between
            the target embedding and each context embedding.
        """
        target, context = pair
        # Pre-TF2.7 pipelines deliver targets with a trailing singleton axis.
        if len(target.shape) == 2:
            target = tf.squeeze(target, axis=1)
        word_emb = self.target_embedding(target)        # (batch, embed)
        context_emb = self.context_embedding(context)   # (batch, context, embed)
        return tf.einsum('be,bce->bc', word_emb, context_emb)  # (batch, context)

In [48]:
embedding_dim = 128

word2vec = Word2Vec(vocab_size, embedding_dim)
# The model emits raw dot-product logits and labels are one-hot ([1, 0, ..., 0]).
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
word2vec.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

In [49]:
# Training logs go here; the %tensorboard cell reads the same directory.
LOG_DIR = 'logs'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)

In [50]:
# A fixed 200-epoch run proved impractical (it was interrupted around epoch 78).
# Stop once the training loss stops improving and keep the best weights.
EPOCHS = 200
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='loss', patience=5, restore_best_weights=True)
history = word2vec.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[tensorboard_callback, early_stopping],
)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Epoch 65/200
Epoch 66/200
Epoch 67/200
Epoch 68/200
Epoch 69/200
Epoch 70/200
Epoch 71/200
Epoch 72/200
Epoch 73/200
Epoch 74/200
Epoch 75/200
Epoch 76/200
Epoch 77/200
Epoch 78

KeyboardInterrupt: 

In [55]:
%load_ext tensorboard
%tensorboard --logdir logs
%matplotlib inline

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 21128), started 0:00:18 ago. (Use '!kill 21128' to kill it.)