In [1]:
import tensorflow as tf
from  tensorflow import keras

import utils

# Module-level dropout rate. It is not referenced inside the classes below;
# each of them receives its own `Drop_rate` constructor argument, which
# shadows this name.
Drop_rate=0.1

In [None]:
class Multihead(keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Projects q/k/v with separate Dense layers, splits each projection into
    ``n_head`` heads of ``model_dim // n_head`` units, attends per head,
    merges the heads, then applies Dense(model_dim) and dropout.

    Args:
        model_dim: output width of the layer.
        n_head: number of attention heads.
        Drop_rate: dropout rate applied to the layer output.
    """

    def __init__(self, model_dim, n_head, Drop_rate):
        super(Multihead, self).__init__()
        self.model_dim = model_dim
        self.n_head = n_head
        self.head_dim = self.model_dim // self.n_head
        self.wq = keras.layers.Dense(self.n_head * self.head_dim)
        self.wk = keras.layers.Dense(self.n_head * self.head_dim)
        self.wv = keras.layers.Dense(self.n_head * self.head_dim)
        self.dense = keras.layers.Dense(self.model_dim)
        self.drop = keras.layers.Dropout(rate=Drop_rate)

    def call(self, q, k, v, mask, training):
        """Attend from ``q`` to ``k``/``v``; returns (batch, q_step, model_dim).

        Assumes q/k/v are rank-3 (batch, step, features) tensors.
        ``mask`` may be None. Otherwise it is multiplied by -1e9 and added to
        the attention scores, so positions where it equals 1 are effectively
        excluded. It must broadcast to (batch, n_head, q_step, k_step).
        ``training`` controls the output dropout.
        """
        o = self.scale_dot(q, k, v, mask)
        o = self.dense(o)
        return self.drop(o, training=training)

    def scale_dot(self, q, k, v, mask):
        """Compute softmax(Q K^T / sqrt(head_dim) + mask * -1e9) V with heads merged."""
        _q = self.reshape(self.wq(q))  # (batch, n_head, q_step, head_dim)
        _k = self.reshape(self.wk(k))  # (batch, n_head, k_step, head_dim)
        _v = self.reshape(self.wv(v))  # (batch, n_head, k_step, head_dim)
        # The dot products run over the per-head width, so that is the scale
        # factor. The raw input width of `k` (k.shape[-1]) is not.
        scale = tf.math.sqrt(tf.cast(self.head_dim, _q.dtype))
        score = tf.matmul(_q, _k, transpose_b=True) / scale  # (batch, n_head, q_step, k_step)
        if mask is not None:
            # Cast so that integer/bool masks do not cause a dtype mismatch.
            score += tf.cast(mask, score.dtype) * -1e9
        attention = tf.nn.softmax(score, axis=-1)
        context = tf.matmul(attention, _v)  # (batch, n_head, q_step, head_dim)
        context = tf.transpose(context, perm=[0, 2, 1, 3])  # (batch, q_step, n_head, head_dim)
        # Use a dynamic batch size and -1 for the step axis, so this works
        # when static shapes are unknown (graph mode / tf.function).
        batch = tf.shape(context)[0]
        return tf.reshape(context, (batch, -1, self.n_head * self.head_dim))

    def reshape(self, x):
        """Split heads: (batch, step, n_head*head_dim) -> (batch, n_head, step, head_dim)."""
        batch = tf.shape(x)[0]
        x = tf.reshape(x, (batch, -1, self.n_head, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

class PositionWiseFFN(keras.layers.Layer):
    """Position-wise feed-forward net: Dense(4*model_dim, relu) -> Dense(model_dim).

    Args:
        model_dim: width of the input and output features.
    """

    def __init__(self, model_dim):
        super(PositionWiseFFN, self).__init__()
        # The activation is configured on the layer itself. Keras' Dense.__call__
        # does not accept an `activation` keyword, which the original call site
        # passed.
        self.dense = keras.layers.Dense(model_dim * 4, activation=keras.activations.relu)
        self.dense1 = keras.layers.Dense(model_dim)

    def call(self, x):
        """Apply the two Dense layers to every position of ``x`` independently."""
        return self.dense1(self.dense(x))

class EncoderLayer(keras.layers.Layer):
    """One encoder block.

    The block has two sublayers: self-attention, then a position-wise FFN.
    Each is followed by a residual connection and LayerNormalization.

    Args:
        model_dim: feature width of the block.
        n_head: number of attention heads.
        Drop_rate: dropout rate for the attention output and the FFN output.
    """

    def __init__(self, model_dim, n_head, Drop_rate):
        super(EncoderLayer, self).__init__()
        self.mul = Multihead(model_dim, n_head, Drop_rate)
        self.ffn = PositionWiseFFN(model_dim)
        self.drop = keras.layers.Dropout(rate=Drop_rate)
        self.l = [
            keras.layers.LayerNormalization(axis=-1),
            keras.layers.LayerNormalization(axis=-1),
        ]

    def call(self, x, mask, training):
        """Run self-attention and the FFN over ``x``; ``mask`` is forwarded to the attention."""
        attn_norm, ffn_norm = self.l
        # `.call` is invoked directly, as elsewhere in this file, so the
        # explicit `mask` argument bypasses Keras' implicit-mask handling.
        attended = attn_norm(x + self.mul.call(x, x, x, mask, training))
        fed = self.drop(self.ffn(attended), training=training)
        return ffn_norm(attended + fed)

class Encoder(keras.layers.Layer):
    """Stack of ``n_layer`` EncoderLayer blocks applied in sequence.

    Args:
        n_layer: number of stacked encoder blocks.
        model_dim: feature width of each block.
        n_head: number of attention heads per block.
        Drop_rate: dropout rate used inside each block.
    """

    def __init__(self, n_layer, model_dim, n_head, Drop_rate):
        # Initialise the Keras base class before assigning any attribute, so
        # Keras attribute tracking is set up (newer Keras rejects the reverse order).
        super(Encoder, self).__init__()
        self.n_layer = n_layer
        self.l = [EncoderLayer(model_dim, n_head, Drop_rate) for _ in range(n_layer)]

    def call(self, x, mask, training):
        """Feed ``x`` through every encoder block with the same ``mask``/``training``."""
        for layer in self.l:
            x = layer.call(x, mask, training)
        return x

class DecoderLayer(keras.layers.Layer):
    """One decoder block.

    The block has three sublayers: masked self-attention, cross-attention
    over the encoder output, and a position-wise FFN. Each is followed by a
    residual connection and LayerNormalization.

    Args:
        model_dim: feature width of the block.
        n_head: number of attention heads.
        Drop_rate: dropout rate for the attention outputs and the FFN output.
    """

    def __init__(self, model_dim, n_head, Drop_rate):
        super(DecoderLayer, self).__init__()
        self.mh = [Multihead(model_dim, n_head, Drop_rate) for _ in range(2)]
        self.ln = [keras.layers.LayerNormalization(axis=-1) for _ in range(3)]
        self.ffn = PositionWiseFFN(model_dim)
        self.drop = keras.layers.Dropout(Drop_rate)

    def call(self, x, y, xz_pad_mask, yz_look_ahead_mask, training):
        """Decode ``y`` while attending to the encoder output ``x``.

        Args:
            x: encoder output (keys/values of the cross-attention).
            y: decoder input (queries).
            xz_pad_mask: mask applied to the cross-attention scores.
            yz_look_ahead_mask: mask applied to the decoder self-attention scores.
            training: dropout switch, forwarded to every sublayer.
        """
        o1 = self.mh[0].call(y, y, y, yz_look_ahead_mask, training)
        o2 = self.ln[0](o1 + y)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        # Multihead.call requires `training`; the original omitted it (TypeError).
        o3 = self.mh[1].call(o2, x, x, xz_pad_mask, training)
        o4 = self.ln[1](o3 + o2)
        # Pass `training` explicitly, as EncoderLayer does, rather than relying
        # on the global learning phase.
        o5 = self.drop(self.ffn(o4), training=training)
        return self.ln[2](o5 + o4)

class Decoder(keras.layers.Layer):
    """Stack of ``n_layer`` DecoderLayer blocks applied in sequence.

    Args:
        model_dim: feature width of each block.
        n_head: number of attention heads per block.
        Drop_rate: dropout rate used inside each block.
        n_layer: number of stacked decoder blocks.
    """

    def __init__(self, model_dim, n_head, Drop_rate, n_layer):
        # Initialise the Keras base class before assigning any attribute, so
        # Keras attribute tracking is set up (newer Keras rejects the reverse order).
        super(Decoder, self).__init__()
        self.n_layer = n_layer
        self.l = [DecoderLayer(model_dim, n_head, Drop_rate) for _ in range(n_layer)]

    def call(self, x, y, xz_pad_mask, yz_look_ahead_mask, training):
        """Feed ``y`` through every decoder block; each block attends to encoder output ``x``."""
        for layer in self.l:
            y = layer.call(x, y, xz_pad_mask, yz_look_ahead_mask, training)
        return y

import numpy as np

class PositionEmbedding(keras.layers.Layer):
    """Token embedding plus a fixed sinusoidal position encoding.

    Args:
        max_len: longest sequence the position table supports.
        model_dim: embedding width.
        n_vocab: vocabulary size of the token embedding.
    """

    def __init__(self, max_len, model_dim, n_vocab):
        super(PositionEmbedding, self).__init__()
        self.max_len = max_len
        self.model_dim = model_dim
        self.emb = keras.layers.Embedding(n_vocab, model_dim)
        # The table depends only on (max_len, model_dim), so build it once
        # here instead of on every call.
        self.pe = self.position()

    def position(self):
        """Return the (1, max_len, model_dim) float32 sinusoidal table.

        angle[pos, i] = pos / 10000 ** (2 * (i // 2) / model_dim).
        Even columns take sin(angle) and odd columns take cos(angle).
        """
        pos = np.arange(self.max_len)[:, None]  # (max_len, 1)
        i = np.arange(self.model_dim)[None, :]  # (1, model_dim)
        # Use float exponents: the original int power overflowed and used the
        # wrong formula. It was also a Python list, so `[None, :]` raised.
        angle = pos / np.power(10000.0, (2 * (i // 2)) / float(self.model_dim))
        angle[:, 0::2] = np.sin(angle[:, 0::2])
        angle[:, 1::2] = np.cos(angle[:, 1::2])
        # float32 so it can be added to the (float32) Embedding output.
        return tf.constant(angle[None, :, :], dtype=tf.float32)

    def call(self, x):
        """Embed token ids ``x`` of shape (batch, step) and add position encodings.

        The table is sliced to the actual sequence length, so inputs shorter
        than ``max_len`` also work.
        """
        seq_len = tf.shape(x)[1]
        return self.pe[:, :seq_len, :] + self.emb(x)

class Transformer(keras.Model):
    """Encoder-decoder Transformer.

    The encoder and decoder share one token + position embedding, and the
    model has a linear output head.

    Args:
        n_head: attention heads per block.
        model_dim: feature width.
        Drop_rate: dropout rate.
        n_layer: number of encoder blocks and of decoder blocks.
        n_vocab: vocabulary size (embedding rows and output logits).
        max_len: sequence length used for position encodings and padding.
        padding_idx: token id treated as padding in masks and in the loss.
    """

    def __init__(self, n_head, model_dim, Drop_rate, n_layer, n_vocab, max_len, padding_idx):
        super(Transformer, self).__init__()
        self.padding_idx = padding_idx
        self.max_len = max_len
        self.encoder = Encoder(n_layer, model_dim, n_head, Drop_rate)
        self.decoder = Decoder(model_dim, n_head, Drop_rate, n_layer)
        self.emb = PositionEmbedding(max_len, model_dim, n_vocab)
        self.o = keras.layers.Dense(n_vocab)
        # Per-token loss (reduction='none'), so padded targets can be masked out in `step`.
        self.loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
        self.opt = keras.optimizers.Adam(0.001)

    def call(self, x, y, training):
        """Return vocabulary logits for decoder input ``y`` conditioned on source ``x``."""
        emb_x = self.emb(x)
        emb_y = self.emb(y)
        pad_mask = self.pad_mask(x)
        xz = self.encoder.call(emb_x, pad_mask, training)
        yz = self.decoder.call(xz, emb_y, pad_mask, self.look_ahead_mask(y), training)
        return self.o(yz)

    def step(self, x, y, training):
        """Run one teacher-forced optimisation step and return the mean token loss.

        The decoder is fed ``y[:, :-1]`` and trained to predict ``y[:, 1:]``.
        Target positions equal to ``padding_idx`` are excluded from the loss.
        """
        target = y[:, 1:]
        with tf.GradientTape() as tape:
            logit = self.call(x, y[:, :-1], training)
            # Keras losses take (y_true, y_pred). The original called the
            # `tf.losses` module itself, which is not callable.
            per_token = self.loss(target, logit)
            # The mask must line up with the shifted targets, not the full `y`.
            bool_mask = tf.math.not_equal(target, self.padding_idx)
            loss = tf.reduce_mean(tf.boolean_mask(per_token, bool_mask))
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss

    def pad_mask(self, seq):
        """Return a float32 (batch, 1, 1, step) mask: 1.0 where ``seq`` is padding."""
        mask = tf.math.equal(seq, self.padding_idx)
        mask = tf.cast(mask, tf.float32)
        return mask[:, None, None, :]

    def look_ahead_mask(self, seq):
        """Return a float32 (batch, 1, step, step) causal + padding mask for ``seq``.

        Entry [b, 0, i, j] is 1.0 if j > i (a future position) or if
        ``seq[b, j]`` is padding; otherwise it is 0.0.
        """
        # Size by the actual decoder input length, which is max_len - 1 during
        # teacher forcing. tf.ones takes a shape tuple; its second positional
        # argument is a dtype.
        seq_len = tf.shape(seq)[1]
        m = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        judge = tf.math.equal(seq, self.padding_idx)
        # Use float 1.0 so both branches of tf.where share m's dtype.
        mask = tf.where(judge[:, None, None, :], 1.0, m[None, None, :, :])
        return mask

    def reference(self, x, v2i, i2v):
        """Greedy-decode each source sequence in ``x`` and return the decoded strings.

        Args:
            x: batch of source token-id sequences.
            v2i: token -> id mapping; must contain '<GO>'.
            i2v: id -> token mapping used to build the output strings.

        Returns:
            One string per source sequence: the concatenated tokens after '<GO>'.
        """
        n = len(x)
        y = [[v2i['<GO>']] for _ in range(n)]
        # NOTE(review): assumes utils.pad_zero returns a writable 2-D numpy int
        # array of width max_len; confirm against utils.
        y = utils.pad_zero(y, self.max_len)
        x = utils.pad_zero(x, self.max_len)
        pad_mask = self.pad_mask(x)
        xz = self.encoder.call(self.emb(x), pad_mask, training=False)
        # Position 0 holds <GO>. The prediction made at position idx fills
        # position idx + 1.
        for idx in range(self.max_len - 1):
            yz = self.decoder.call(xz, self.emb(y), pad_mask, self.look_ahead_mask(y), training=False)
            logits = self.o(yz)  # project hidden states to vocabulary scores
            y[:, idx + 1] = np.argmax(logits[:, idx, :], axis=-1)
        # One string per batch item (not per max_len position).
        return ["".join([i2v[item] for item in y[j, 1:]]) for j in range(n)]