# In [6]:
from tensor2tensor.models import transformer
from tensor2tensor.layers import common_layers
import tensorflow as tf

# In [None]:
class AttentionCNN(transformer.Transformer):
    """Sequence-CNN with interwoven self-attention.

    Nicknamed "Gated Transcription Factor cOnvolutions" (GTFO).

    Reuses the Transformer encoder with forced settings (dot-product
    self-attention followed by a ``conv_relu_conv`` feed-forward block), maps
    the encoded input sequence onto ``target_length`` positions with a dense
    layer, and scores each position with a sigmoid cross-entropy loss.
    """

    def bottom(self, features):
        """Embed DNA token ids into ``hidden_size``-dimensional vectors.

        Replaces ``features["inputs"]`` in place with its embedding and
        returns the same ``features`` dict.
        """
        # ids ==> DNA embeddings.
        # Assumes exactly 5 ids (A, C, T, G, N) — TODO confirm against the
        # problem's vocabulary/encoder.
        features["inputs"] = common_layers.embedding(
            features["inputs"],
            vocab_size=5,
            dense_size=self._hparams.hidden_size)
        return features

    def top(self, body_output, unused_features):
        """Collapse the hidden dimension into one logit per target position.

        Averages ``body_output`` over its last axis. Applied to the
        [batch_size, target_length, hidden_dim] output of ``body``, this
        yields logits of shape [batch_size, target_length].
        """
        return tf.reduce_mean(body_output, axis=-1)

    def loss(self, logits, features):
        """Return ``(loss_numerator, loss_denominator)`` for training.

        Uses sigmoid cross-entropy, one independent binary label per target
        position, smoothed by ``hparams.label_smoothing``. The denominator is
        fixed at 1.0 because ``tf.losses.sigmoid_cross_entropy`` already
        returns a reduced scalar with its default reduction.
        """
        # Hyperparameters live on the model instance; there is no global.
        hparams = self._hparams
        labels = features["targets"]
        # NOTE(review): assumes `labels` is shape-compatible with `logits`
        # ([batch_size, target_length]). If targets arrive as 4-D
        # [batch, length, 1, 1], they may need squeezing first. Confirm.
        loss_num = tf.losses.sigmoid_cross_entropy(
            labels, logits, label_smoothing=hparams.label_smoothing)
        loss_denom = 1.0
        return (loss_num, loss_denom)

    def body(self, features):
        """Encode the inputs and project them onto the target positions.

        Runs the Transformer encoder with forced attention/FFN settings. It
        then applies a ReLU dense layer along the input-length axis, so the
        output has one hidden vector per target position.

        Returns:
          Tensor of shape [batch_size, target_length, hidden_dim].

        Raises:
          ValueError: if the target length (``targets.shape[1]``) is not
            statically known. The dense projection needs a fixed number of
            units.
        """
        hparams = self._hparams

        inputs = features["inputs"]
        target_space = features["target_space_id"]

        # Force these settings: full dot-product self-attention plus a
        # convolutional FFN. (An alternative pairing would be
        # "local_unmasked" attention with a "dense_relu_dense" FFN.)
        # By default the conv FFN uses a kernel of size 3, and local
        # attention uses a window/block of 128 (per hparams defaults; confirm).
        full_att_conv = ("dot_product", "conv_relu_conv")
        # Tuple order is (self-attention type, ffn layer); unpack to match.
        hparams.self_attention_type, hparams.ffn_layer = full_att_conv
        body_output, _ = self.encode(
            inputs, target_space, hparams, features=features)

        targets = features["targets"]
        # Need a plain int for the dense layer's unit count, not a Dimension.
        target_length = targets.shape.as_list()[1]
        if target_length is None:
            raise ValueError(
                "AttentionCNN requires a statically known target length; "
                "got targets with shape %s" % (targets.shape,))

        # Dense connection along input_length. Move that axis last so the
        # dense layer mixes sequence positions rather than hidden features.
        # NOTE(review): this also requires input_length to be static.
        layer_input = tf.transpose(body_output, [0, 2, 1])  # [batch_size, hidden_dim, input_length]
        layer_output = common_layers.dense(
            layer_input, target_length, activation=tf.nn.relu)  # [batch_size, hidden_dim, target_length]

        # NOTE(review): layer_postprocess may add a residual connection,
        # depending on hparams.layer_postprocess_sequence. That residual only
        # matches shapes when input_length == target_length. Confirm intent.
        body_output = common_layers.layer_postprocess(
            layer_input, layer_output, hparams)
        body_output = tf.transpose(body_output, [0, 2, 1])  # [batch_size, target_length, hidden_dim]

        return body_output